Summary¶

Baseline LSTM model that uses CAISO SP15 pricing and load data to predict the real-time market price for the next day. It is fit on data from 01/01/2023 to 03/01/2025 and predicts the rest of 2025.

Model inputs are:

  • day-ahead market price
  • day-ahead load forecast
  • prior day real-time price
  • day, hour, night/day, year indicators

Output:

  • real-time price deviation from the day-ahead price for the entire next 24-hour block
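
As a rough illustration of the output described above, the sketch below builds the deviation target (real-time minus day-ahead price) and groups it into 24-hour blocks. `DAM_PRC` matches the column name used later in this notebook; `RTM_PRC`, `make_daily_targets`, and the toy data are hypothetical, and the project's actual preprocessing may differ.

```python
import numpy as np
import pandas as pd


def make_daily_targets(prices: pd.DataFrame, block: int = 24) -> np.ndarray:
    """Return an (n_days, block) array of RTM-minus-DAM deviations.

    Assumes hourly rows in chronological order; trailing hours that do not
    fill a complete block are dropped.
    """
    deviation = (prices["RTM_PRC"] - prices["DAM_PRC"]).to_numpy()
    n_days = len(deviation) // block
    return deviation[: n_days * block].reshape(n_days, block)


# Toy data: two days of hourly prices (values are made up for illustration)
toy = pd.DataFrame({
    "DAM_PRC": np.full(48, 40.0),
    "RTM_PRC": np.arange(48, dtype=float),
})
targets = make_daily_targets(toy)
print(targets.shape)
```

Predicting the deviation rather than the raw price lets the day-ahead price carry the bulk of the signal, which is why the plots in this notebook add `dam` back to the predictions before display.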

Baseline hyperparameters¶

hidden_size: 64
learning_rate: 0.001
batch_size: 16
dropout: 0.2

CRPS = 6.44

Comments¶

With relatively little tuning, this LSTM model produces reasonable predictions given the narrow scope of the training data.

Coverage is slightly biased: true values in the lower and upper ranges fall outside the predicted quantiles more often than the nominal levels imply. The model still captures the general daily trend well, and its uncertainty bounds mostly correspond to reality.
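
To make the coverage claim concrete, a minimal sketch of how the hit rate of a central prediction interval can be checked is given below. The function name `interval_coverage` and the toy numbers are hypothetical; they are not taken from the project code.

```python
import numpy as np


def interval_coverage(y_true, lower, upper):
    """Return the fractions of observations inside, below, and above an interval."""
    y_true = np.asarray(y_true, dtype=float)
    lower = np.asarray(lower, dtype=float)
    upper = np.asarray(upper, dtype=float)
    inside = ((y_true >= lower) & (y_true <= upper)).mean()
    below = (y_true < lower).mean()
    above = (y_true > upper).mean()
    return inside, below, above


# Toy example: five observations against a fixed [0.5, 5.0] interval
y_toy = np.array([0.0, 1.0, 2.0, 3.0, 10.0])
inside, below, above = interval_coverage(y_toy, np.full(5, 0.5), np.full(5, 5.0))
print(inside, below, above)
```

For a well-calibrated 1-99% interval, `inside` should be close to 0.98, with `below` and `above` each near 0.01; larger tail fractions indicate the kind of bias described above.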

It also completely fails to account for large price spikes, which happen too infrequently for the model to predict accurately. A second model would likely need to be trained specifically to identify periods of high price-spike probability. That model would need more data, including additional nodes for spike examples and other data types such as generation mix and weather forecasts. This is considered out of scope for this project, but implementing it would be relatively straightforward within the project's modular structure.
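
One possible starting point for such a spike model is a binary label per 24-hour block, sketched below. The 200 $/MWh threshold, the function name `spike_labels`, and the toy prices are assumptions made only for illustration; a suitable threshold would need to be chosen from the data.

```python
import numpy as np


def spike_labels(rtm_price, threshold=200.0, block=24):
    """Label each block 1 if any hourly price exceeds the threshold, else 0.

    The threshold is a hypothetical value, not one used by this project.
    """
    prices = np.asarray(rtm_price, dtype=float)
    n_blocks = len(prices) // block
    blocks = prices[: n_blocks * block].reshape(n_blocks, block)
    return (blocks > threshold).any(axis=1).astype(int)


# Toy data: ten flat days with a single spike in the second day
toy_prices = np.full(24 * 10, 40.0)
toy_prices[30] = 900.0
labels = spike_labels(toy_prices)
print(labels, labels.mean())
```

The low positive rate in even this toy example hints at the class imbalance a spike classifier would face, which is consistent with the need for more spike examples noted above.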

In [1]:
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly
import plotly.colors as pc
import plotly.graph_objects as go
import yaml

from price_forecasting.config import MODELS_DIR, PROCESSED_DATA_DIR
from price_forecasting.utils.scoring_tools import get_mean_crps

Loading Data¶

In [2]:
MODEL_DIR = MODELS_DIR / 'LSTM_v1'

# load config file and variables from model run
with open(MODEL_DIR / 'config.yaml', 'r') as f:
    config = yaml.safe_load(f)
DATA_SOURCE = PROCESSED_DATA_DIR / config['data_source']
quantiles = np.array(config['quantiles'])

# load X data in DataFrame format
X_test = pd.read_parquet(DATA_SOURCE / 'X_test.pqt')

# load y_test data and format as numpy array
y_test = pd.read_parquet(DATA_SOURCE / 'y_test.pqt').to_numpy()
y_test = y_test.reshape([-1])

# load y predictions from model
y_pred = np.load(MODEL_DIR / 'y_pred.npy')

dam = X_test['DAM_PRC']  # day-ahead price data

time = X_test.index
time = time.tz_convert('US/Pacific')

Plotting¶

In [3]:
plotly.offline.init_notebook_mode()

def quantile_plot():
    fig = go.Figure()

    n = len(quantiles)
    colors = pc.sample_colorscale('Viridis', [i/n for i in range(n)])

    for i, yi in enumerate(y_pred.T):
        if i == 0: 
            fill = None
        else:
            fill = 'tonexty'
        fig.add_trace(go.Scatter(x=time, y=yi + dam, mode='lines', name=quantiles[i], 
                                 line=dict(width=1,color=colors[i]), fill=fill))

    fig.add_trace(go.Scatter(x=time, y=y_test + dam, mode='lines', name='RTM Price',
                              line=dict(width=2, color='black', dash='dot')))

    start = datetime.fromisoformat('2025-03-01 00:00:00')
    end = datetime.fromisoformat('2025-03-08 00:00:00')

    fig.update_layout(title='SP15 RTM Price Prediction',
                      xaxis_title='Time',
                      width=1100, 
                      height=500,
                      yaxis_title='Price ($/MWh)',
                      xaxis=dict(range=[start, end]),   
                      yaxis=dict(range=[-100, 250]), 
                     )

    fig.show()

quantile_plot()

Model Quantification¶

CRPS¶

In [4]:
crps = get_mean_crps(y_pred, y_test, quantiles)
print(f"CRPS: {crps:.3}")
CRPS: 6.44
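
The implementation of `get_mean_crps` is not shown in this notebook. One common way to approximate CRPS from quantile forecasts is twice the pinball (quantile) loss averaged over the quantile levels. The sketch below illustrates that approximation under the assumption of a roughly evenly spaced quantile grid; `pinball_crps` is a hypothetical name, and its result may not match the project's function exactly.

```python
import numpy as np


def pinball_crps(y_pred, y_true, quantiles):
    """Approximate mean CRPS as 2 x the average pinball loss.

    y_pred: (n_samples, n_quantiles) predicted quantile values
    y_true: (n_samples,) observed values
    quantiles: (n_quantiles,) quantile levels in (0, 1)
    """
    y_pred = np.asarray(y_pred, dtype=float)
    y_true = np.asarray(y_true, dtype=float)[:, None]
    q = np.asarray(quantiles, dtype=float)[None, :]
    diff = y_true - y_pred
    # Pinball loss: q * diff when the truth is above the prediction,
    # (q - 1) * diff when it is below
    pinball = np.maximum(q * diff, (q - 1.0) * diff)
    return 2.0 * pinball.mean()


# Toy example: one observation, three quantile levels
score = pinball_crps([[0.0, 1.0, 2.0]], [1.0], [0.25, 0.5, 0.75])
print(score)
```

Because CRPS is expressed in the units of the target, the value reported here can be read in $/MWh; lower is better.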

Coverage¶

In [5]:
def coverage_calibration_plot():
    # fraction of observations at or below each predicted quantile
    coverage = (y_test[:, None] <= y_pred).mean(axis=0)

    fig = go.Figure()

    fig.add_trace(go.Scatter(x=quantiles, y=coverage-quantiles, mode="markers", name='Measured',
                              marker=dict(size=8, color='blue')))

    fig.add_trace(go.Scatter(x=[0,1], y=[0,0], mode='lines', name='Expected',
                              line=dict(width=2, color='black', dash='dot')))

    fig.update_layout(title='Coverage Calibration',
                      xaxis_title='Quantile',
                      width=650, 
                      height=500,
                      yaxis_title='Coverage Bias',
                      yaxis=dict(range=[-0.05, 0.05]), 
                      xaxis=dict(range=[0.0, 1.0]), 
                     )

    fig.show()

coverage_calibration_plot()
In [6]:
def interval_violation_plot():
    fig = go.Figure()

    q_hi = 0.99
    q_low = 0.01

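    # np.isclose avoids relying on exact float equality when locating a quantile
    i_low = np.where(np.isclose(quantiles, q_low))[0][0]
    y_pred_low = y_pred[:,i_low]
    # flag hours where the true value falls below the lower bound
    under = (y_test < y_pred_low).astype(float) * -200

    i_hi = np.where(np.isclose(quantiles, q_hi))[0][0]
    y_pred_hi = y_pred[:,i_hi]
    # flag hours where the true value exceeds the upper bound
    over = (y_test > y_pred_hi).astype(float) * 200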

    fig.add_trace(go.Scatter(x=time, y=y_pred_low + dam, mode='lines', name=q_low, 
                            line=dict(width=1,color="gray"), fill=None))


    fig.add_trace(go.Scatter(x=time, y=y_pred_hi + dam, mode='lines', name=q_hi, 
                            line=dict(width=1,color="gray"), fill="tonexty"))

    fig.add_trace(go.Scatter(x=time, y=y_test + dam, mode='lines', name='RTM Price',
                          line=dict(width=2, color='black', dash='dot')))

    fig.add_trace(go.Scatter(x=time, y=under, mode='lines', name='Under', 
                            line=dict(width=0,color="blue"), fill="tozeroy"))

    fig.add_trace(go.Scatter(x=time, y=over, mode='lines', name='Over', 
                            line=dict(width=0,color="red"), fill="tozeroy"))


    start = datetime.fromisoformat('2025-03-01 00:00:00')
    end = datetime.fromisoformat('2025-03-08 00:00:00')

    fig.update_layout(title='1-99% Interval Violation Plot',
                      xaxis_title='Time',
                      width=1000, 
                      height=500,
                      yaxis_title='Price ($/MWh)',
                      xaxis=dict(range=[start, end]),   
                      yaxis=dict(range=[-100, 250]), 
                     )

    fig.show()
interval_violation_plot()

Expected Value Residuals¶

In [7]:
def residual_plot():
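    # Approximate the expected value by integrating the predicted quantile
    # values over the available quantile levels; the tails outside the
    # lowest and highest levels are not included.
    evs = np.trapezoid(y_pred, quantiles, axis=1)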

    residual = y_test - evs

    fig = go.Figure()

    fig.add_trace(go.Scatter(x=y_test, y=residual, mode="markers", name='Measured',
                              marker=dict(size=3, color='blue')))
    fig.add_trace(go.Scatter(x=[-100,1000], y=[0,0], mode='lines', name='Expected',
                              line=dict(width=2, color='black', dash='dot')))

    fig.update_layout(title='Expected Value Residuals',
                  xaxis_title='True Value',
                  width=1000, 
                  height=500,
                  yaxis_title='Residual',
                  xaxis=dict(range=[-100, 250]),   
                  yaxis=dict(range=[-100, 250]), 
                 )
    fig.show()
residual_plot()

Probability Integral Transform¶

In [8]:
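# PIT: locate each observation within its predicted quantile curve.
# np.interp assumes each row of y_pred is increasing (no quantile crossing).
PIT = []
for val, pred in zip(y_test, y_pred):
    p = np.interp(val, pred, quantiles, left=0.0, right=1.0)
    PIT.append(p)

plt.hist(PIT, bins=4, density=True)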
plt.axhline(1, color='k', linestyle='--')
plt.xlabel('Quartile Bin')
plt.ylabel('Probability Density')
plt.title("Probability Integral Transform")
plt.show()

Quantile Spread¶

In [9]:
# Measure prediction spread
def spread_plot():
    fig = go.Figure()

    q_hi = 0.95
    q_low = 0.05

    # np.isclose avoids relying on exact float equality when locating a quantile
    i_low = np.where(np.isclose(quantiles, q_low))[0][0]
    y_pred_low = y_pred[:,i_low]

    i_hi = np.where(np.isclose(quantiles, q_hi))[0][0]
    y_pred_hi = y_pred[:,i_hi]

    spread = y_pred_hi - y_pred_low
    avg_width = spread.mean()
    print(f'Average 5-95% Prediction Width: {avg_width:.{3}} $/MWh')

    fig.add_trace(go.Scatter(x=time, y=spread, mode='lines', name="5-95 Spread", 
                            line=dict(width=2,color="blue"), fill=None))

    fig.add_trace(go.Scatter(x=time, y=y_test + dam, mode='lines', name='RTM Price',
                          line=dict(width=1, color='black', dash='dot')))


    start = datetime.fromisoformat('2025-03-01 00:00:00')
    end = datetime.fromisoformat('2025-03-08 00:00:00')

    fig.update_layout(title='5-95% Spread Measure',
                      xaxis_title='Time',
                      width=1000, 
                      height=500,
                      yaxis_title='Price ($/MWh)',
                      xaxis=dict(range=[start, end]),   
                      yaxis=dict(range=[-100, 250]), 
                     )
    fig.show()

    plt.hist(spread, density=True,)
    plt.xlabel('Spread ($/MWh)')
    plt.ylabel('Probability Density')
    plt.title("Spread Distribution")
    plt.show()

spread_plot()
Average 5-95% Prediction Width: 33.4 $/MWh